#!/usr/bin/env python
# -*- coding: utf-8 -*-

from datetime import datetime

import numpy as np
import pandas as pd
from django.db.models import Q
from matplotlib import pyplot as plt
from statsmodels.tsa.arima.model import ARIMA

from web.constants.datetime_format import DatetimeFormat
from web.manager.log_manager import LogManager
from web.models.stock_index_week import StockIndexWeek
from web.task.base_task import BaseTask
from web.util.datetime_util import DatetimeUtil

Logger = LogManager.get_logger(__name__)


class ArimaTask(BaseTask):
    """
    ARIMA (AutoRegressive Integrated Moving Average) forecast of StockIndexWeek close prices.
    Reference: https://blog.csdn.net/sixpp/article/details/142679195
    """

    def do_task(self, code: str, begin_date: str, end_date: str, *,
                order: tuple = (5, 1, 0), steps: int = 300):
        """
        Fit an ARIMA model to the close prices of ``code`` and plot history plus forecast.

        :param code: value matched against ``StockIndexWeek.code_``.
        :param begin_date: lower bound of ``end_date``, parsed with ``DatetimeFormat.Date_Format``.
        :param end_date: upper bound of ``end_date`` (inclusive, ``__range``), same format.
        :param order: ARIMA ``(p, d, q)`` order passed to statsmodels ``ARIMA``.
        :param steps: number of future periods to forecast.
        :return: None. Returns early, after logging, when no rows match.
        """
        Logger.info('开始执行arima算法')

        # Parse into new names so the ``str`` parameters keep their declared type.
        begin = DatetimeUtil.str_to_datetime(begin_date, DatetimeFormat.Date_Format)
        end = DatetimeUtil.str_to_datetime(end_date, DatetimeFormat.Date_Format)

        # One query, fetching only the columns used below (the original hit the DB twice).
        rows = list(
            StockIndexWeek.objects
            .filter(Q(end_date__range=[begin, end]) & Q(code_=code))
            .order_by('end_date')
            .values('end_date', 'close_price')
        )
        if not rows:
            Logger.info(f'code={code}在{begin_date}至{end_date}期间没有数据，不执行arima算法')
            return

        frame = pd.DataFrame(rows)
        # Cast once to float; the ORM may return non-float values (e.g. Decimal),
        # and both the model and the plot should use the same numeric series.
        close_prices = frame['close_price'].to_numpy(dtype=float)
        history_dates = pd.to_datetime(frame['end_date'])

        # Fit the model and forecast ``steps`` periods past the last observation.
        arima_result = ARIMA(close_prices, order=order).fit()
        forecast = np.asarray(arima_result.forecast(steps=steps))

        # Place the forecast AFTER the history instead of over its first rows.
        # The spacing is the typical gap between consecutive end_date values,
        # falling back to one week when there is only one observation.
        gaps = history_dates.diff().dropna()
        step = gaps.median() if not gaps.empty else pd.Timedelta(weeks=1)
        forecast_dates = pd.date_range(start=history_dates.iloc[-1], periods=steps + 1, freq=step)[1:]

        plt.figure(figsize=(14, 7))
        plt.plot(history_dates, close_prices, label='Close Price')
        plt.plot(forecast_dates, forecast, label='ARIMA Forecast', color='orange')
        plt.title('ARIMA Stock Price Forecast')
        plt.xlabel('Date')
        plt.ylabel('Price')
        plt.legend()
        plt.show()


if __name__ == '__main__':
    # Manual run: forecast index code 000001 over the sample date range below.
    ArimaTask().do_task(code="000001", begin_date="20010101", end_date="20150630")
